import asyncio
import io
import logging
import time
from collections import deque
from dataclasses import dataclass
from functools import wraps
from inspect import isasyncgenfunction, iscoroutinefunction
from typing import (
    Any,
    AsyncGenerator,
    Callable,
    Coroutine,
    Dict,
    Generic,
    Iterable,
    List,
    Literal,
    Optional,
    Protocol,
    Set,
    Tuple,
    TypeVar,
    overload,
)

from ray import serve
from ray._common.signature import extract_signature, flatten_args, recover_args
from ray._common.utils import get_or_create_event_loop
from ray.serve._private.constants import SERVE_LOGGER_NAME
from ray.serve._private.utils import extract_self_if_method_call
from ray.serve.exceptions import RayServeException
from ray.util.annotations import PublicAPI

# Module logger; shares the Serve logger name (SERVE_LOGGER_NAME).
logger = logging.getLogger(SERVE_LOGGER_NAME)


# The user can return these values in their streaming batch handler function to
# indicate that a request is finished, so Serve can terminate the request.
# NOTE: these are the exception *classes* themselves (not instances), placed at
# a caller's position in a yielded batch list; that caller's stream then ends.
USER_CODE_STREAMING_SENTINELS = [StopIteration, StopAsyncIteration]


@dataclass
class _SingleRequest:
    """A single caller's request waiting in a _BatchQueue."""

    # The `self` of the decorated method, or None for a plain function call.
    self_arg: Any
    # Call arguments in Ray's flattened format (alternating key/value slots,
    # with None keys for positional args); `self` has already been stripped.
    flattened_args: List[Any]
    # Resolved with this caller's result (or exception) once its batch runs.
    future: asyncio.Future
    # Serve request context captured when the request was enqueued.
    request_context: serve.context._RequestContext


@dataclass
class _GeneratorResult:
    """One streamed value for a caller of a generator batch function."""

    # The value produced for this caller in the current yielded batch.
    result: Any
    # Future that receives this caller's next _GeneratorResult, or is failed
    # with StopAsyncIteration (stream finished) or the batch's exception.
    next_future: asyncio.Future


@dataclass
class _RuntimeSummaryStatistics:
    start_times: List[float]

    @property
    def min_start_time(self) -> Optional[float]:
        return min(self.start_times) if self.start_times else None

    @property
    def mean_start_time(self) -> Optional[float]:
        return (
            sum(self.start_times) / len(self.start_times) if self.start_times else None
        )

    @property
    def max_start_time(self) -> Optional[float]:
        return max(self.start_times) if self.start_times else None

    @property
    def num_requests(self) -> int:
        return len(self.start_times)


def _batch_args_kwargs(
    list_of_flattened_args: List[List[Any]],
) -> Tuple[Tuple[Any], Dict[Any, Any]]:
    """Merge per-request flattened args into one batched call.

    Ray flattens a call into alternating key/value slots, e.g.
    ``args=(1, 2), kwargs={"key": "val"}`` becomes
    ``[None, 1, None, 2, "key", "val"]``. Key slots (even indices) are taken
    from the first request; each value slot (odd index) becomes the list of
    that slot's values across all requests. The merged flat list is then
    converted back into ``(args, kwargs)``.
    """
    distinct_lengths = {len(flat) for flat in list_of_flattened_args}
    assert (
        len(distinct_lengths) == 1
    ), "All batch requests should have the same number of parameters."
    num_slots = distinct_lengths.pop()

    template = list_of_flattened_args[0]
    merged = [
        template[slot]
        if slot % 2 == 0
        else [flat[slot] for flat in list_of_flattened_args]
        for slot in range(num_slots)
    ]
    return recover_args(merged)


class _BatchQueue:
    def __init__(
        self,
        max_batch_size: int,
        batch_wait_timeout_s: float,
        max_concurrent_batches: int,
        handle_batch_func: Optional[Callable] = None,
        batch_size_fn: Optional[Callable[[List], int]] = None,
    ) -> None:
        """Async queue that accepts individual items and returns batches.

        Respects max_batch_size and batch_wait_timeout_s; a batch is returned
        once its effective size reaches max_batch_size or batch_wait_timeout_s
        has elapsed since the batch's first request was dequeued.

        If handle_batch_func is passed in, a background coroutine will run to
        poll from the queue and call handle_batch_func on the results.

        Cannot be pickled.

        Arguments:
            max_batch_size: max effective size of a batch (the number of
                requests, or the value returned by batch_size_fn if provided).
            batch_wait_timeout_s: time to wait before returning an incomplete
                batch.
            max_concurrent_batches: max number of batches to run concurrently.
            handle_batch_func(Optional[Callable]): callback to run in the
                background to handle batches if provided.
            batch_size_fn(Optional[Callable[[List], int]]): optional function to
                compute the effective batch size. If None, uses len(batch).
                The function takes a list of request items and returns an
                integer representing the batch size. This is useful for
                batching based on custom metrics such as total nodes in graphs,
                total tokens in sequences, etc.
        """
        self.queue: asyncio.Queue[_SingleRequest] = asyncio.Queue()
        self.max_batch_size = max_batch_size
        self.batch_wait_timeout_s = batch_wait_timeout_s
        self.max_concurrent_batches = max_concurrent_batches
        self.batch_size_fn = batch_size_fn
        # Caps how many batches run handle_batch_func at the same time.
        self.semaphore = asyncio.Semaphore(max_concurrent_batches)
        # Set on every put(); cleared by wait_for_batch once the queue is empty.
        self.requests_available_event = asyncio.Event()
        # In-flight batch tasks; removed again in _handle_completed_task.
        self.tasks: Set[asyncio.Task] = set()

        # Used for observability.
        self.curr_iteration_start_times: Dict[asyncio.Task, float] = {}

        self._handle_batch_task = None
        self._loop = get_or_create_event_loop()
        if handle_batch_func is not None:
            self._handle_batch_task = self._loop.create_task(
                self._process_batches(handle_batch_func)
            )
        self._warn_if_max_batch_size_exceeds_max_ongoing_requests()

    def _warn_if_max_batch_size_exceeds_max_ongoing_requests(self):
        """Helper to check whether the max_batch_size is bounded.

        Log a warning to configure `max_ongoing_requests` if it's bounded.
        """
        max_ongoing_requests = (
            serve.get_replica_context()._deployment_config.max_ongoing_requests
        )
        if max_ongoing_requests < self.max_batch_size * self.max_concurrent_batches:
            logger.warning(
                f"`max_batch_size` ({self.max_batch_size}) * `max_concurrent_batches` "
                f"({self.max_concurrent_batches}) is larger than `max_ongoing_requests` "
                f"({max_ongoing_requests}). This means the replica will never achieve "
                "the configured `max_batch_size` concurrently. Please update "
                "`max_ongoing_requests` to be >= `max_batch_size` * `max_concurrent_batches`."
            )

    def set_max_batch_size(self, new_max_batch_size: int) -> None:
        """Updates queue's max_batch_size."""
        self.max_batch_size = new_max_batch_size
        self._warn_if_max_batch_size_exceeds_max_ongoing_requests()

    def put(self, request: _SingleRequest) -> None:
        """Enqueues a single request and wakes up the batching loop."""
        self.queue.put_nowait(request)
        self.requests_available_event.set()

    def _compute_batch_size(self, batch: List[_SingleRequest]) -> int:
        """Compute the effective batch size using batch_size_fn or len().

        With a batch_size_fn, each request's item is recovered from its
        flattened args: the first positional argument, or, if the caller
        passed the item by keyword, the first keyword argument's value.
        """
        if self.batch_size_fn is None:
            return len(batch)

        items = []
        for request in batch:
            args, kwargs = recover_args(request.flattened_args)
            if args:
                items.append(args[0])
            elif kwargs:
                # The item was passed by keyword, e.g. handler(request=x).
                items.append(next(iter(kwargs.values())))
            else:
                raise ValueError(
                    "Cannot compute the batch size of a request without arguments."
                )

        return self.batch_size_fn(items)

    async def wait_for_batch(self) -> List[_SingleRequest]:
        """Wait for a batch respecting max_batch_size and batch_wait_timeout_s.

        Blocks indefinitely until a first request arrives, then keeps adding
        requests until the batch's effective size reaches max_batch_size or
        batch_wait_timeout_s has elapsed since that first request was dequeued.
        Both settings are snapshotted when the first request arrives.

        With a batch_size_fn, a request that would push the batch over
        max_batch_size is re-queued (at the back) for a later batch, and the
        current batch is returned immediately.

        Returns an empty list if the first request was rejected: it alone
        exceeds max_batch_size, or batch_size_fn raised on it. In that case the
        request's future is failed with the corresponding exception.
        """
        first_item = await self.queue.get()  # Block until first item arrives

        # Cache current max_batch_size and batch_wait_timeout_s for this batch.
        max_batch_size = self.max_batch_size
        batch_wait_timeout_s = self.batch_wait_timeout_s

        batch = [first_item]
        if self.batch_size_fn is None:
            current_batch_size = 1
        else:
            try:
                current_batch_size = self._compute_batch_size(batch)
            except Exception as e:
                # Fail only this request; the batching loop must keep running.
                logger.exception("batch_size_fn raised an exception.")
                _set_exception_if_not_done(first_item.future, e)
                return []
            if current_batch_size > max_batch_size:
                exc = RuntimeError(
                    "Size of item is greater than max_batch_size. "
                    "Please increase the max_batch_size or check the "
                    "implementation of the batch_size_fn."
                )
                # The future may already be cancelled; never raise here or the
                # background batching loop would die.
                _set_exception_if_not_done(first_item.future, exc)
                return []

        # Monotonic clock so wall-clock adjustments can't stretch the wait.
        batch_start_time = time.monotonic()
        while True:
            remaining_batch_time_s = max(
                batch_wait_timeout_s - (time.monotonic() - batch_start_time), 0
            )
            try:
                # Wait for new arrivals.
                await asyncio.wait_for(
                    self.requests_available_event.wait(), remaining_batch_time_s
                )
            except asyncio.TimeoutError:
                pass

            if self.batch_size_fn is None:
                # Default behavior: the batch size is the number of requests.
                while len(batch) < max_batch_size and not self.queue.empty():
                    batch.append(self.queue.get_nowait())
                current_batch_size = len(batch)
                deferred = False
            else:
                current_batch_size, deferred = self._fill_batch_with_size_fn(
                    batch, current_batch_size, max_batch_size
                )

            # Only clear the put event if the queue is empty. If it's not empty
            # we can start constructing a new batch immediately in the next loop.
            # The code that puts items into the queue runs on the same event loop
            # as this code, so there's no race condition between the time we
            # get objects in the queue (and clear the event) and when objects
            # get added to the queue.
            if self.queue.empty():
                self.requests_available_event.clear()

            if deferred:
                # The next request doesn't fit, so this batch is complete.
                break
            if (
                time.monotonic() - batch_start_time >= batch_wait_timeout_s
                or current_batch_size >= max_batch_size
            ):
                break

        return batch

    def _fill_batch_with_size_fn(
        self,
        batch: List[_SingleRequest],
        current_batch_size: int,
        max_batch_size: int,
    ) -> Tuple[int, bool]:
        """Moves queued requests into `batch` while its effective size fits.

        Requests on which batch_size_fn raises are failed individually and
        dropped. Returns (new effective batch size, whether a request was
        deferred because it would exceed max_batch_size).
        """
        while not self.queue.empty():
            next_item = self.queue.get_nowait()
            # Temporarily add to check size.
            batch.append(next_item)
            try:
                new_size = self._compute_batch_size(batch)
            except Exception as e:
                batch.pop()
                logger.exception("batch_size_fn raised an exception.")
                _set_exception_if_not_done(next_item.future, e)
                continue

            if new_size > max_batch_size:
                batch.pop()
                # NOTE: The deferred item goes to the back of the queue (FIFO),
                # so newer requests may be processed before it. Consider using
                # asyncio.PriorityQueue if strict ordering is required.
                self.queue.put_nowait(next_item)
                return current_batch_size, True

            current_batch_size = new_size
        return current_batch_size, False

    def _validate_results(
        self, results: Iterable[Any], input_batch_length: int
    ) -> None:
        """Raises if the batch function returned a different number of results."""
        if len(results) != input_batch_length:
            raise RayServeException(
                "Batched function doesn't preserve batch size. "
                f"The input list has length {input_batch_length} but the "
                f"returned list has length {len(results)}."
            )

    async def _consume_func_generator(
        self,
        func_generator: AsyncGenerator,
        initial_futures: List[asyncio.Future],
        input_batch_length: int,
    ) -> None:
        """Consumes batch function generator.

        This function only runs if the function decorated with @serve.batch
        is a generator. Each yielded list is fanned out to the callers: a
        caller's current future is resolved with a _GeneratorResult that
        carries the next future to wait on.
        """

        FINISHED_TOKEN = None

        futures = deque(initial_futures)
        try:
            assert len(futures) == input_batch_length

            async for results in func_generator:
                self._validate_results(results, input_batch_length)
                for idx in range(input_batch_length):
                    result, future = results[idx], futures[0]

                    if future is FINISHED_TOKEN:
                        # This caller has already terminated.
                        futures.append(FINISHED_TOKEN)
                    elif any(
                        result is sentinel
                        for sentinel in USER_CODE_STREAMING_SENTINELS
                    ):
                        # User's code returned sentinel. No values left
                        # for caller. Terminate iteration for caller.
                        # NOTE: identity check on purpose -- `in` would call
                        # `==`, which is elementwise (and raises on `bool()`)
                        # for array-like results such as numpy arrays.
                        _set_exception_if_not_done(future, StopAsyncIteration)
                        futures.append(FINISHED_TOKEN)
                    else:
                        next_future = get_or_create_event_loop().create_future()
                        _set_result_if_not_done(
                            future, _GeneratorResult(result, next_future)
                        )
                        futures.append(next_future)

                    # Remove processed future. We remove the future at the very
                    # end of the loop to ensure that if an exception occurs,
                    # all pending futures will get set in the `except` block.
                    futures.popleft()

            for future in futures:
                if future is not FINISHED_TOKEN:
                    _set_exception_if_not_done(future, StopAsyncIteration)
        except Exception as e:
            for future in futures:
                if future is not FINISHED_TOKEN:
                    _set_exception_if_not_done(future, e)

    async def _assign_func_results(
        self,
        func_future: asyncio.Future,
        futures: List[asyncio.Future],
        input_batch_length: int,
    ):
        """Assigns func's results to the list of futures."""

        try:
            results = await func_future
            self._validate_results(results, input_batch_length)
            for result, future in zip(results, futures):
                _set_result_if_not_done(future, result)
        except Exception as e:
            for future in futures:
                _set_exception_if_not_done(future, e)

    async def _process_batches(self, func: Callable) -> None:
        """Loops infinitely and processes queued request batches."""
        # When asyncio task is created, the task will inherit the request context from the current context.
        # So we unset the request context so the current context is not inherited by the task, _process_batch.
        serve.context._unset_request_context()
        while not self._loop.is_closed():
            batch = await self.wait_for_batch()
            if not batch:
                # The dequeued request was rejected (its future already holds
                # the error); there is nothing to run.
                continue
            task = asyncio.create_task(self._process_batch(func, batch))
            self.tasks.add(task)
            self.curr_iteration_start_times[task] = time.time()
            task.add_done_callback(self._handle_completed_task)

    async def _process_batch(self, func: Callable, batch: List[_SingleRequest]) -> None:
        """Processes queued request batch."""
        # NOTE: this semaphore caps the number of concurrent batches specified by `max_concurrent_batches`
        async with self.semaphore:
            # Remove requests that have been cancelled from the batch. If
            # all requests have been cancelled, simply return and wait for
            # the next batch.
            batch = [req for req in batch if not req.future.cancelled()]
            if len(batch) == 0:
                return

            futures = [item.future for item in batch]

            # Most of the logic in the function should be wrapped in this try-
            # except block, so the futures' exceptions can be set if an exception
            # occurs. Otherwise, the futures' requests may hang indefinitely.
            try:
                self_arg = batch[0].self_arg
                args, kwargs = _batch_args_kwargs(
                    [item.flattened_args for item in batch]
                )

                # Method call.
                if self_arg is not None:
                    func_future_or_generator = func(self_arg, *args, **kwargs)
                # Normal function call.
                else:
                    func_future_or_generator = func(*args, **kwargs)

                # Add individual request context to the batch request context
                serve.context._set_batch_request_context(
                    [req.request_context for req in batch]
                )

                if isasyncgenfunction(func):
                    func_generator = func_future_or_generator
                    await self._consume_func_generator(
                        func_generator, futures, len(batch)
                    )
                else:
                    func_future = func_future_or_generator
                    await self._assign_func_results(func_future, futures, len(batch))
            except Exception as e:
                logger.exception("_process_batch ran into an unexpected exception.")

                for future in futures:
                    _set_exception_if_not_done(future, e)
            finally:
                # Reset the batch request context after the batch is processed,
                # including when processing failed.
                serve.context._set_batch_request_context([])

    def _handle_completed_task(self, task: asyncio.Task) -> None:
        """Done-callback for batch tasks: drops bookkeeping and logs failures."""
        self.tasks.remove(task)
        del self.curr_iteration_start_times[task]
        # Task.exception() raises CancelledError for a cancelled task, so the
        # cancelled case must be checked first.
        if task.cancelled():
            self._log_if_exception(asyncio.CancelledError())
        else:
            self._log_if_exception(task.exception())

    @staticmethod
    def _log_if_exception(exception_maybe: Optional[BaseException]) -> None:
        """Logs `exception_maybe` (with its traceback) unless it is None."""
        if exception_maybe is not None:
            if isinstance(exception_maybe, asyncio.CancelledError):
                logger.debug("Task was cancelled")
            else:
                # Not inside an `except` block, so pass the exception
                # explicitly; logger.exception() would log no traceback here.
                logger.error("Task failed unexpectedly", exc_info=exception_maybe)

    def __del__(self):
        # getattr: __init__ may have failed before this attribute was set.
        handle_batch_task = getattr(self, "_handle_batch_task", None)
        if handle_batch_task is None or not get_or_create_event_loop().is_running():
            return

        # TODO(edoakes): although we try to gracefully shutdown here, it still
        # causes some errors when the process exits due to the asyncio loop
        # already being destroyed.
        handle_batch_task.cancel()


class _LazyBatchQueueWrapper:
    """Stores a _BatchQueue and updates its settings.

    _BatchQueue cannot be pickled, you must construct it lazily
    at runtime inside a replica. This class initializes a queue only upon
    first access.
    """

    def __init__(
        self,
        max_batch_size: int = 10,
        batch_wait_timeout_s: float = 0.0,
        max_concurrent_batches: int = 1,
        handle_batch_func: Optional[Callable] = None,
        batch_size_fn: Optional[Callable[[List], int]] = None,
    ):
        self._queue: Optional[_BatchQueue] = None
        self.max_batch_size = max_batch_size
        self.batch_wait_timeout_s = batch_wait_timeout_s
        self.max_concurrent_batches = max_concurrent_batches
        self.handle_batch_func = handle_batch_func
        self.batch_size_fn = batch_size_fn

    @property
    def queue(self) -> _BatchQueue:
        """Returns _BatchQueue.

        Initializes queue when called for the first time.
        """
        if self._queue is None:
            self._queue = _BatchQueue(
                self.max_batch_size,
                self.batch_wait_timeout_s,
                self.max_concurrent_batches,
                self.handle_batch_func,
                self.batch_size_fn,
            )
        return self._queue

    def set_max_batch_size(self, new_max_batch_size: int) -> None:
        """Updates queue's max_batch_size."""

        self.max_batch_size = new_max_batch_size

        if self._queue is not None:
            self._queue.set_max_batch_size(new_max_batch_size)

    def set_batch_wait_timeout_s(self, new_batch_wait_timeout_s: float) -> None:
        self.batch_wait_timeout_s = new_batch_wait_timeout_s

        if self._queue is not None:
            self._queue.batch_wait_timeout_s = new_batch_wait_timeout_s

    def get_max_batch_size(self) -> int:
        return self.max_batch_size

    def get_batch_wait_timeout_s(self) -> float:
        return self.batch_wait_timeout_s

    def _get_curr_iteration_start_times(self) -> _RuntimeSummaryStatistics:
        """Gets summary statistics of current iteration's start times."""
        return _RuntimeSummaryStatistics(
            list(self.queue.curr_iteration_start_times.values())
        )

    async def _is_batching_task_alive(self) -> bool:
        """Gets whether default _BatchQueue's background task is alive.

        Returns False if the batch handler doesn't use a default _BatchQueue.
        """

        if hasattr(self.queue, "_handle_batch_task"):
            return not self.queue._handle_batch_task.done()
        else:
            return False

    async def _get_handling_task_stack(self) -> Optional[str]:
        """Gets the stack for the default _BatchQueue's background task.

        Returns empty string if the batch handler doesn't use a default _BatchQueue.
        """

        if hasattr(self.queue, "_handle_batch_task"):
            str_buffer = io.StringIO()
            self.queue._handle_batch_task.print_stack(file=str_buffer)
            return str_buffer.getvalue()
        else:
            return None


def _validate_max_batch_size(max_batch_size):
    if not isinstance(max_batch_size, int):
        if isinstance(max_batch_size, float) and max_batch_size.is_integer():
            max_batch_size = int(max_batch_size)
        else:
            raise TypeError(
                f"max_batch_size must be integer >= 1, got {max_batch_size}"
            )

    if max_batch_size < 1:
        raise ValueError(
            f"max_batch_size must be an integer >= 1, got {max_batch_size}"
        )


def _validate_batch_wait_timeout_s(batch_wait_timeout_s):
    if not isinstance(batch_wait_timeout_s, (float, int)):
        raise TypeError(
            f"batch_wait_timeout_s must be a float >= 0, got {batch_wait_timeout_s}"
        )

    if batch_wait_timeout_s < 0:
        raise ValueError(
            f"batch_wait_timeout_s must be a float >= 0, got {batch_wait_timeout_s}"
        )


def _validate_max_concurrent_batches(max_concurrent_batches: int) -> None:
    if not isinstance(max_concurrent_batches, int) or max_concurrent_batches < 1:
        raise TypeError(
            f"max_concurrent_batches must be an integer >= 1, got {max_concurrent_batches}"
        )


def _validate_batch_size_fn(batch_size_fn: Optional[Callable[[List], int]]) -> None:
    if batch_size_fn is not None and not callable(batch_size_fn):
        raise TypeError(
            f"batch_size_fn must be a callable or None, got {type(batch_size_fn)}"
        )


# Type of the instance a batched method is bound to. Declared contravariant
# because it only appears in argument position of the method protocols below.
SelfType = TypeVar("SelfType", contravariant=True)
# Type of a single (unbatched) input item.
T = TypeVar("T")
# Type of a single (unbatched) output item.
R = TypeVar("R")


class _SyncBatchingMethod(Protocol, Generic[SelfType, T, R]):
    """Typing-only shape of a sync method taking (self, list of items) -> list."""

    def __call__(self, self_: SelfType, __batch: List[T], /) -> List[R]:
        ...


class _AsyncBatchingMethod(Protocol, Generic[SelfType, T, R]):
    """Typing-only shape of an async method taking (self, list of items) -> list."""

    async def __call__(self, self_: SelfType, __batch: List[T], /) -> List[R]:
        ...


# The overloads below only tell static type checkers how `batch` maps the
# decorated callable's signature (a list of items -> a list of results) to the
# wrapper's signature (one item -> one result). They have no runtime effect.
# NOTE: the sync overloads describe typing only; at runtime the bare
# `@serve.batch` form raises TypeError for non-async functions.
@overload  # Sync function for `batch` called WITHOUT arguments
def batch(_sync_func: Callable[[List[T]], List[R]], /) -> Callable[[T], R]:
    ...


@overload  # Async function for `batch` called WITHOUT arguments
def batch(
    _async_func: Callable[[List[T]], Coroutine[Any, Any, List[R]]], /
) -> Callable[[T], Coroutine[Any, Any, R]]:
    ...


@overload  # Sync method for `batch` called WITHOUT arguments
def batch(
    _sync_meth: _SyncBatchingMethod[SelfType, T, R], /
) -> Callable[[SelfType, T], R]:
    ...


@overload  # Async method for `batch` called WITHOUT arguments
def batch(
    _async_meth: _AsyncBatchingMethod[SelfType, T, R], /
) -> Callable[[SelfType, T], Coroutine[Any, Any, R]]:
    ...


# Parametrized form: returns a decorator (see `_BatchDecorator`).
@overload  # `batch` called WITH arguments
def batch(
    _: Literal[None] = None,
    /,
    max_batch_size: int = 10,
    batch_wait_timeout_s: float = 0.01,
    max_concurrent_batches: int = 1,
    batch_size_fn: Optional[Callable[[List], int]] = None,
) -> "_BatchDecorator":
    ...


class _BatchDecorator(Protocol):
    """Describes the decorator returned by calling `batch` with arguments.

    Typing-only: each overload maps a batch function/method signature to the
    signature of the resulting single-item wrapper.
    """

    @overload  # Sync function
    def __call__(self, _sync_func: Callable[[List[T]], List[R]], /) -> Callable[[T], R]:
        ...

    @overload  # Async function
    def __call__(
        self, _async_func: Callable[[List[T]], Coroutine[Any, Any, List[R]]], /
    ) -> Callable[[T], Coroutine[Any, Any, R]]:
        ...

    @overload  # Sync method
    def __call__(
        self, _sync_meth: _SyncBatchingMethod[SelfType, T, R], /
    ) -> Callable[[SelfType, T], R]:
        ...

    @overload  # Async method
    def __call__(
        self, _async_meth: _AsyncBatchingMethod[SelfType, T, R], /
    ) -> Callable[[SelfType, T], Coroutine[Any, Any, R]]:
        ...


@PublicAPI(stability="stable")
def batch(
    _func: Optional[Callable] = None,
    /,
    max_batch_size: int = 10,
    batch_wait_timeout_s: float = 0.01,
    max_concurrent_batches: int = 1,
    batch_size_fn: Optional[Callable[[List], int]] = None,
) -> Callable:
    """Converts a function to asynchronously handle batches.

    The function can be a standalone function or a class method. In both
    cases, the function must be `async def` and take a list of objects as
    its sole argument and return a list of the same length as a result.

    The function may also be an async generator that yields lists of the same
    length as its input. Each caller then receives an async generator over its
    own element of every yielded list. Yielding the `StopIteration` or
    `StopAsyncIteration` class at a caller's position ends that caller's stream.

    When invoked, the caller passes a single object. These will be batched
    and executed asynchronously once there is a batch of `max_batch_size`
    or `batch_wait_timeout_s` has elapsed, whichever occurs first.

    `max_batch_size` and `batch_wait_timeout_s` can be updated using setter
    methods from the batch_handler (`set_max_batch_size` and
    `set_batch_wait_timeout_s`).

    Example:

    .. code-block:: python

            from ray import serve
            from starlette.requests import Request

            @serve.deployment
            class BatchedDeployment:
                @serve.batch(max_batch_size=10, batch_wait_timeout_s=0.1)
                async def batch_handler(self, requests: List[Request]) -> List[str]:
                    response_batch = []
                    for r in requests:
                        name = (await r.json())["name"]
                        response_batch.append(f"Hello {name}!")

                    return response_batch

                def update_batch_params(self, max_batch_size, batch_wait_timeout_s):
                    self.batch_handler.set_max_batch_size(max_batch_size)
                    self.batch_handler.set_batch_wait_timeout_s(batch_wait_timeout_s)

                async def __call__(self, request: Request):
                    return await self.batch_handler(request)

            app = BatchedDeployment.bind()

    Arguments:
        max_batch_size: the maximum batch size that will be executed in
            one call to the underlying function.
        batch_wait_timeout_s: the maximum duration to wait for
            `max_batch_size` elements before running the current batch.
        max_concurrent_batches: the maximum number of batches that can be
            executed concurrently. If the number of concurrent batches exceeds
            this limit, the batch handler will wait for a batch to complete
            before sending the next batch to the underlying function.
        batch_size_fn: optional function to compute the effective batch size.
            If provided, this function takes a list of items and returns an
            integer representing the batch size. This is useful for batching
            based on custom metrics such as total nodes in graphs, total tokens
            in sequences, or other domain-specific measures. If None, the batch
            size is computed as len(batch).

    Raises:
        TypeError: if the decorated object is not callable, or (bare
            `@serve.batch` form) is neither a coroutine function nor an async
            generator function.
    """
    # `_func` will be None in the case when the decorator is parametrized.
    # See the comment at the end of this function for a detailed explanation.
    if _func is not None:
        if not callable(_func):
            raise TypeError(
                "@serve.batch can only be used to decorate functions or methods."
            )

        # Async generator functions are `async def` too (streaming batches),
        # but inspect.iscoroutinefunction() returns False for them.
        if not (iscoroutinefunction(_func) or isasyncgenfunction(_func)):
            raise TypeError("Functions decorated with @serve.batch must be 'async def'")

    _validate_max_batch_size(max_batch_size)
    _validate_batch_wait_timeout_s(batch_wait_timeout_s)
    _validate_max_concurrent_batches(max_concurrent_batches)
    _validate_batch_size_fn(batch_size_fn)

    def _batch_decorator(_func):
        lazy_batch_queue_wrapper = _LazyBatchQueueWrapper(
            max_batch_size,
            batch_wait_timeout_s,
            max_concurrent_batches,
            _func,
            batch_size_fn,
        )
        # The signature never changes, so extract it once instead of per call.
        func_signature = extract_signature(_func)

        async def batch_handler_generator(
            first_future: asyncio.Future,
        ) -> AsyncGenerator:
            """Generator that handles generator batch functions."""

            future = first_future
            while True:
                try:
                    async_response: _GeneratorResult = await future
                    future = async_response.next_future
                    yield async_response.result
                except StopAsyncIteration:
                    break

        def enqueue_request(args, kwargs) -> asyncio.Future:
            """Queues one call and returns the future for its result."""
            flattened_args: List = flatten_args(func_signature, args, kwargs)

            # If the function is a method, remove self as an argument.
            self_arg = extract_self_if_method_call(args, _func)
            if self_arg is not None:
                flattened_args = flattened_args[2:]

            batch_queue = lazy_batch_queue_wrapper.queue

            future = get_or_create_event_loop().create_future()
            request_context = serve.context._get_serve_request_context()
            batch_queue.put(
                _SingleRequest(self_arg, flattened_args, future, request_context)
            )
            return future

        @wraps(_func)
        def generator_batch_wrapper(*args, **kwargs):
            first_future = enqueue_request(args, kwargs)
            return batch_handler_generator(first_future)

        @wraps(_func)
        async def batch_wrapper(*args, **kwargs):
            # This will raise if the underlying call raised an exception.
            return await enqueue_request(args, kwargs)

        if isasyncgenfunction(_func):
            wrapper = generator_batch_wrapper
        else:
            wrapper = batch_wrapper

        # We store the lazy_batch_queue_wrapper's getters and setters as
        # batch_wrapper attributes, so they can be accessed in user code.
        wrapper._get_max_batch_size = lazy_batch_queue_wrapper.get_max_batch_size
        wrapper._get_batch_wait_timeout_s = (
            lazy_batch_queue_wrapper.get_batch_wait_timeout_s
        )
        wrapper.set_max_batch_size = lazy_batch_queue_wrapper.set_max_batch_size
        wrapper.set_batch_wait_timeout_s = (
            lazy_batch_queue_wrapper.set_batch_wait_timeout_s
        )

        # Store debugging methods in the lazy_batch_queue wrapper
        wrapper._get_curr_iteration_start_times = (
            lazy_batch_queue_wrapper._get_curr_iteration_start_times
        )
        wrapper._is_batching_task_alive = (
            lazy_batch_queue_wrapper._is_batching_task_alive
        )
        wrapper._get_handling_task_stack = (
            lazy_batch_queue_wrapper._get_handling_task_stack
        )

        return wrapper

    # Unfortunately, this is required to handle both non-parametrized
    # (@serve.batch) and parametrized (@serve.batch(**kwargs)) usage.
    # In the former case, `serve.batch` will be called with the underlying
    # function as the sole argument. In the latter case, it will first be
    # called with **kwargs, then the result of that call will be called
    # with the underlying function as the sole argument (i.e., it must be a
    # "decorator factory.").
    return _batch_decorator(_func) if callable(_func) else _batch_decorator


def _set_result_if_not_done(future: asyncio.Future, result: Any):
    """Sets the future's result if the future is not done."""

    if not future.done():
        future.set_result(result)


def _set_exception_if_not_done(future: asyncio.Future, exception: Any):
    """Sets the future's exception if the future is not done."""

    if not future.done():
        future.set_exception(exception)
